{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f54ecf0b",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "# HuggingFace Tutorial Series\n",
    "- 1. What is HuggingFace?\n",
    "- 2. Common tasks we can do with HuggingFace, with a brief explanation of each (e.g. what question answering is)\n",
    "- 3. Using the HuggingFace Pipeline (High level feature)\n",
    "- 4. How the pipeline works at a lower level\n",
    "- 5. HuggingFace Datasets\n",
    "- 6. HuggingFace Tokenizer\n",
    "- 7. HuggingFace Evaluate\n",
    "- 8. HuggingFace Trainer\n",
    "- 9. Putting it together to finetune a news article summarizer\n",
    "- 10. Making it more general and robust with Lightning and custom data loading\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec1aae37",
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "# Silence every Python warning raised in this kernel.\n",
    "# NOTE(review): this also hides deprecation warnings, e.g. for `load_metric`, which recent\n",
    "# `datasets` releases deprecate in favour of `evaluate.load` -- confirm for the pinned version.\n",
    "warnings.simplefilter(\"ignore\")\n",
    "\n",
    "import os\n",
    "# Number GPUs in PCI bus order and expose only GPU 0 to this process. These variables only\n",
    "# take effect if they are set before CUDA is initialized, so keep them ahead of any GPU work.\n",
    "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\"\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"\n",
    "\n",
    "# NOTE(review): np, datasets, load_dataset, AutoModel, DataCollatorForSeq2Seq,\n",
    "# Seq2SeqTrainingArguments and Seq2SeqTrainer are not referenced elsewhere in this notebook.\n",
    "import numpy as np\n",
    "import torch\n",
    "import datasets \n",
    "import pytorch_lightning as pl\n",
    "from datasets import load_dataset, load_metric\n",
    "\n",
    "from transformers import (\n",
    "    AutoModel,\n",
    "    AutoModelForSeq2SeqLM,\n",
    "    AutoTokenizer,\n",
    "    DataCollatorForSeq2Seq,\n",
    "    Seq2SeqTrainingArguments,\n",
    "    Seq2SeqTrainer,\n",
    ")\n",
    "\n",
    "# NOTE(review): torch and pytorch_lightning are already imported above; re-importing just\n",
    "# rebinds the same cached module objects.\n",
    "import torch\n",
    "import pandas as pd\n",
    "from torch.utils.data import Dataset\n",
    "import pytorch_lightning as pl\n",
    "\n",
    "# Let float32 matmuls use faster, lower-precision internal math on GPUs that support it.\n",
    "torch.set_float32_matmul_precision(\"medium\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5fd7cb0c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hugging Face checkpoint name; the tokenizer is loaded from this same checkpoint.\n",
    "model_name = \"t5-small\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "418cb03a",
   "metadata": {},
   "outputs": [],
   "source": [
    "class cnn_dailymail(Dataset):\n",
    "    def __init__(self, csv_file, tokenizer, max_length=512):\n",
    "        self.data = pd.read_csv(csv_file)\n",
    "        self.tokenizer = tokenizer\n",
    "        self.max_length = max_length\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.data)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        article = self.data.loc[idx, 'article']\n",
    "        highlights = self.data.loc[idx, 'highlights']\n",
    "\n",
    "        inputs = self.tokenizer(\n",
    "            article,\n",
    "            truncation=True,\n",
    "            padding='max_length',\n",
    "            max_length=self.max_length,\n",
    "            return_tensors='pt'\n",
    "        )\n",
    "        targets = self.tokenizer(\n",
    "            highlights,\n",
    "            truncation=True,\n",
    "            padding='max_length',\n",
    "            max_length=self.max_length,\n",
    "            return_tensors='pt'\n",
    "        )\n",
    "\n",
    "        return {\n",
    "            'input_ids': inputs['input_ids'].squeeze(),\n",
    "            'attention_mask': inputs['attention_mask'].squeeze(),\n",
    "            'labels': targets['input_ids'].squeeze()\n",
    "        }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aaa62755",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MyDataModule(pl.LightningDataModule):\n",
    "    \"\"\"Builds train/val/test ``cnn_dailymail`` datasets from CSV paths and serves DataLoaders.\n",
    "\n",
    "    Args:\n",
    "        train_csv, val_csv, test_csv: CSV paths passed to ``cnn_dailymail``.\n",
    "        tokenizer: Tokenizer forwarded to every dataset.\n",
    "        batch_size: Batch size used by all three DataLoaders.\n",
    "        max_length: Token length forwarded to ``cnn_dailymail``.\n",
    "        train_num_workers: Worker processes for the training DataLoader (default 4).\n",
    "        eval_num_workers: Worker processes for the validation/test DataLoaders (default 2).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, train_csv, val_csv, test_csv, tokenizer, batch_size=16, max_length=512,\n",
    "                 train_num_workers=4, eval_num_workers=2):\n",
    "        super().__init__()\n",
    "        self.train_csv = train_csv\n",
    "        self.val_csv = val_csv\n",
    "        self.test_csv = test_csv\n",
    "        self.tokenizer = tokenizer\n",
    "        self.batch_size = batch_size\n",
    "        self.max_length = max_length\n",
    "        self.train_num_workers = train_num_workers\n",
    "        self.eval_num_workers = eval_num_workers\n",
    "        # Populated lazily by setup(); None means the split has not been built yet.\n",
    "        self.train_dataset = None\n",
    "        self.val_dataset = None\n",
    "        self.test_dataset = None\n",
    "\n",
    "    def _make_dataset(self, csv_file):\n",
    "        \"\"\"Construct a ``cnn_dailymail`` dataset for one split.\"\"\"\n",
    "        return cnn_dailymail(csv_file, self.tokenizer, self.max_length)\n",
    "\n",
    "    def setup(self, stage=None):\n",
    "        \"\"\"Build the datasets required for ``stage``.\n",
    "\n",
    "        Lightning calls this with 'fit', 'validate', 'test' or 'predict'; ``None`` builds\n",
    "        all splits. Splits that are already built (e.g. by an earlier manual ``setup()``)\n",
    "        are reused instead of re-reading their CSV.\n",
    "        \"\"\"\n",
    "        if stage in ('fit', None) and self.train_dataset is None:\n",
    "            self.train_dataset = self._make_dataset(self.train_csv)\n",
    "        # trainer.validate() calls setup('validate'), which needs the validation split too.\n",
    "        if stage in ('fit', 'validate', None) and self.val_dataset is None:\n",
    "            self.val_dataset = self._make_dataset(self.val_csv)\n",
    "        if stage in ('test', None) and self.test_dataset is None:\n",
    "            self.test_dataset = self._make_dataset(self.test_csv)\n",
    "\n",
    "    def train_dataloader(self):\n",
    "        return torch.utils.data.DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.train_num_workers)\n",
    "\n",
    "    def val_dataloader(self):\n",
    "        return torch.utils.data.DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.eval_num_workers)\n",
    "\n",
    "    def test_dataloader(self):\n",
    "        return torch.utils.data.DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.eval_num_workers)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fbb699e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MyLightningModule(pl.LightningModule):\n",
    "    \"\"\"Fine-tunes a HuggingFace seq2seq model and logs validation ROUGE-1.\n",
    "\n",
    "    Args:\n",
    "        model_name: Checkpoint passed to ``AutoModelForSeq2SeqLM.from_pretrained``.\n",
    "        learning_rate: AdamW learning rate.\n",
    "        weight_decay: AdamW weight decay.\n",
    "        tokenizer: Tokenizer used to find the pad id and to decode validation outputs.\n",
    "            Defaults to ``AutoTokenizer.from_pretrained(model_name)`` instead of relying\n",
    "            on a notebook-global ``tokenizer``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, model_name, learning_rate, weight_decay, tokenizer=None):\n",
    "        super().__init__()\n",
    "        self.model_name = model_name\n",
    "        self.learning_rate = learning_rate\n",
    "        self.weight_decay = weight_decay\n",
    "        \n",
    "        # Load the pre-trained model (wrapped in torch.compile) and its tokenizer\n",
    "        self.model = torch.compile(AutoModelForSeq2SeqLM.from_pretrained(self.model_name))\n",
    "        self.tokenizer = tokenizer if tokenizer is not None else AutoTokenizer.from_pretrained(self.model_name)\n",
    "        \n",
    "        # Load the ROUGE metric\n",
    "        self.metric = load_metric(\"rouge\")\n",
    "\n",
    "        # Per-epoch validation buffers holding token ids (one tensor per batch).\n",
    "        self._val_pred_ids = []\n",
    "        self._val_label_ids = []\n",
    "\n",
    "    def _mask_padding(self, labels):\n",
    "        \"\"\"Return ``labels`` with pad-token positions set to -100 so the loss ignores them.\n",
    "\n",
    "        HF seq2seq models skip label index -100 in their cross-entropy loss; without this\n",
    "        the model is also trained to predict the padding after every summary.\n",
    "        \"\"\"\n",
    "        pad_id = self.tokenizer.pad_token_id\n",
    "        if labels is None or pad_id is None:\n",
    "            return labels\n",
    "        return labels.masked_fill(labels == pad_id, -100)\n",
    "\n",
    "    def forward(self, input_ids, attention_mask, labels=None):\n",
    "        output = self.model(\n",
    "            input_ids=input_ids,\n",
    "            attention_mask=attention_mask,\n",
    "            labels=labels,\n",
    "        )\n",
    "        return output.loss, output.logits\n",
    "    \n",
    "    def training_step(self, batch, batch_idx):\n",
    "        input_ids = batch[\"input_ids\"]\n",
    "        attention_mask = batch[\"attention_mask\"]\n",
    "        labels = batch[\"labels\"]\n",
    "        loss, logits = self(input_ids, attention_mask, self._mask_padding(labels))\n",
    "        self.log('train_loss', loss, on_epoch=True, on_step=True, prog_bar=True)\n",
    "        return {'loss': loss, 'logits': logits}\n",
    "    \n",
    "    def on_validation_epoch_start(self):\n",
    "        # Start every validation epoch (including the sanity check) with empty buffers.\n",
    "        self._val_pred_ids.clear()\n",
    "        self._val_label_ids.clear()\n",
    "\n",
    "    def validation_step(self, batch, batch_idx):\n",
    "        input_ids = batch[\"input_ids\"]\n",
    "        attention_mask = batch[\"attention_mask\"]\n",
    "        labels = batch[\"labels\"]\n",
    "        loss, logits = self(input_ids, attention_mask, self._mask_padding(labels))\n",
    "        self.log('val_loss', loss, on_epoch=True, on_step=False)\n",
    "        \n",
    "        # Buffer only argmax token ids, moved to CPU. Concatenating the full\n",
    "        # (batch, seq, vocab) logits for the whole validation set grows without bound\n",
    "        # on the GPU, and repeated torch.cat re-copies the whole buffer every step.\n",
    "        self._val_pred_ids.append(logits.argmax(dim=-1).detach().cpu())\n",
    "        self._val_label_ids.append(labels.detach().cpu())\n",
    "            \n",
    "        return {'loss': loss, 'logits': logits, \"labels\":labels}\n",
    "    \n",
    "    def on_validation_epoch_end(self):\n",
    "        \"\"\"Decode buffered predictions/labels and log ROUGE-1 precision, recall and F1.\n",
    "\n",
    "        NOTE(review): predictions are the argmax of teacher-forced logits rather than\n",
    "        ``generate()`` output, so these scores likely overstate real summary quality.\n",
    "        \"\"\"\n",
    "        if not self._val_pred_ids:\n",
    "            return  # no validation batches ran this epoch\n",
    "\n",
    "        pad_id = self.tokenizer.pad_token_id\n",
    "        decoded_preds, decoded_labels = [], []\n",
    "        for pred_ids, label_ids in zip(self._val_pred_ids, self._val_label_ids):\n",
    "            if pad_id is not None:\n",
    "                # -100 is not a real token id and cannot be decoded; map it back to padding.\n",
    "                label_ids = label_ids.masked_fill(label_ids == -100, pad_id)\n",
    "            decoded_preds.extend(self.tokenizer.batch_decode(pred_ids, skip_special_tokens=True))\n",
    "            decoded_labels.extend(self.tokenizer.batch_decode(label_ids, skip_special_tokens=True))\n",
    "\n",
    "        # Compute ROUGE-1 and take the `.mid` aggregate score\n",
    "        scores = self.metric.compute(predictions=decoded_preds, references=decoded_labels, rouge_types=[\"rouge1\"])[\"rouge1\"].mid\n",
    "\n",
    "        self.log('rouge1_precision', scores.precision, prog_bar=True)\n",
    "        self.log('rouge1_recall', scores.recall, prog_bar=True)\n",
    "        self.log('rouge1_fmeasure', scores.fmeasure, prog_bar=True)\n",
    "\n",
    "        # Release the buffers before the next validation epoch\n",
    "        self._val_pred_ids.clear()\n",
    "        self._val_label_ids.clear()\n",
    "    \n",
    "    def configure_optimizers(self):\n",
    "        optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)\n",
    "        return optimizer\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd63c628",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# File paths\n",
    "train_csv = \"train.csv\"\n",
    "val_csv = \"validation.csv\"\n",
    "test_csv = \"test.csv\"\n",
    "\n",
    "# Seed Python, NumPy and torch RNGs so data shuffling and dropout are reproducible.\n",
    "pl.seed_everything(42)\n",
    "\n",
    "# Create the data module. trainer.fit() calls dm.setup(\"fit\") itself, so no manual\n",
    "# setup() is needed (setup() with no stage would also load the unused test split).\n",
    "dm = MyDataModule(train_csv, val_csv, test_csv, tokenizer, batch_size=16)\n",
    "\n",
    "# Reuse `model_name` so the model always matches the tokenizer loaded above.\n",
    "model = MyLightningModule(model_name=model_name, learning_rate=1e-4, weight_decay=1e-5)\n",
    "trainer = pl.Trainer(accelerator=\"gpu\", devices=[0], max_epochs=1, precision=16)\n",
    "trainer.fit(model, datamodule=dm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5d3d684",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stray local Jupyter URL, commented out: as bare code it is a SyntaxError that halts Run All.\n",
    "# http://localhost:18888/notebooks/cnndaily_t5_lightning_customdataloading.ipynb"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a0494596",
   "metadata": {},
   "source": [
    "### Next steps\n",
    "* Articles longer than 512 tokens are currently truncated. Check whether this causes problems for much longer articles.\n",
    "\n",
    "#### What we've done\n",
    "* Change the data loading to be more general: load examples from CSV files on disk and tokenize them on the fly.\n",
    "* Add `torch.compile`.\n",
    "* 1. Clean up the code and move it from the notebook into scripts, then train for an epoch (add multi-GPU training?).\n",
    "* Add TensorBoard visualization.\n",
    "* Train from scratch instead of from pretrained weights, to verify that the training setup works and the model actually improves.\n",
    "* 2. Create an inference step: send in a news article, get a summary, and check that it works.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80a2efab",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f9b71ab",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
